package com.socket.secure.filter;

import cn.hutool.core.date.SystemClock;
import cn.hutool.extra.servlet.ServletUtil;
import cn.hutool.http.Header;
import cn.hutool.http.useragent.UserAgent;
import cn.hutool.http.useragent.UserAgentParser;
import com.socket.secure.constant.RequsetTemplate;
import com.socket.secure.constant.SecureConstant;
import com.socket.secure.constant.SecureProperties;
import com.socket.secure.event.entity.InitiatorEvent;
import com.socket.secure.exception.ExpiredRequestException;
import com.socket.secure.exception.InvalidRequestException;
import com.socket.secure.exception.RepeatedRequestException;
import com.socket.secure.filter.anno.RequestEnc;
import com.socket.secure.filter.validator.RepeatValidator;
import com.socket.secure.util.Assert;
import com.socket.secure.util.IPHash;
import com.socket.secure.util.MappingUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * Decrypts and validates incoming requests whose handler carries {@link RequestEnc}.
 * <p>
 * {@link #filter} has the same shape as {@code javax.servlet.Filter#doFilter}; presumably a
 * registered servlet filter delegates to it (TODO confirm). Requests that map to no handler, or
 * to a handler without {@link RequestEnc} on either its bean type or its method, are forwarded
 * unchanged.
 */
@Component
public final class SecureRequsetFilter {
    private static final Logger log = LoggerFactory.getLogger(SecureRequsetFilter.class);
    private ApplicationEventPublisher publisher;
    private SecureProperties properties;
    private RepeatValidator validator;

    /**
     * Decrypt and verify the request, then continue the chain with the decrypted wrapper.
     * <p>
     * When verification fails with an {@link InvalidRequestException} or an
     * {@link IllegalArgumentException}, the response status is set to
     * {@link SecureConstant#VERIFY_FAILED_HTTP_ERROR_CODE}, an {@link InitiatorEvent} is published,
     * the failure message is logged at WARN and the chain is NOT continued.
     *
     * @param request  incoming request; cast to {@link HttpServletRequest}
     * @param response outgoing response; cast to {@link HttpServletResponse} on verification failure
     * @param chain    remaining filter chain
     * @throws IOException      if the chain (or, presumably, request wrapping/decryption) fails with one
     * @throws ServletException if the chain fails with one
     */
    public void filter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HandlerMethod handler = MappingUtil.getHandlerMethod(httpRequest);
        RequestEnc marker = handler == null ? null : findMarker(handler);
        if (marker == null) {
            // Unmapped URI, or handler not marked for encryption: pass the original request through
            chain.doFilter(request, response);
            return;
        }
        SecureRequestWrapper decrypted;
        try {
            decrypted = this.verify(httpRequest, marker);
        } catch (InvalidRequestException | IllegalArgumentException ex) {
            ((HttpServletResponse) response).setStatus(SecureConstant.VERIFY_FAILED_HTTP_ERROR_CODE);
            this.pushEvent(httpRequest, handler, ex.getMessage());
            log.warn(ex.getMessage());
            return;
        }
        // Kept outside the try so exceptions raised downstream are not treated as verification failures
        chain.doFilter(decrypted, response);
    }

    /**
     * Resolves the {@link RequestEnc} marker for a handler. An annotation on the bean type wins;
     * the method-level annotation is consulted only when the type carries none.
     *
     * @param handler the mapped handler method
     * @return the marker, or {@code null} if neither the bean type nor the method is annotated
     */
    private static RequestEnc findMarker(HandlerMethod handler) {
        RequestEnc onType = handler.getBeanType().getAnnotation(RequestEnc.class);
        return onType != null ? onType : handler.getMethod().getAnnotation(RequestEnc.class);
    }

    /**
     * Runs the verification pipeline for one request, in this order: optional client-IP check,
     * decryption, expiry check, replay check and signature comparison. The first failing step
     * throws; the exception types come from the {@code Assert} helpers and the wrapper.
     *
     * @param request the raw HTTP request
     * @param marker  resolved {@link RequestEnc}; its {@code sign()} value is passed to decryption
     * @return the decrypted wrapper that should replace the request downstream
     */
    private SecureRequestWrapper verify(HttpServletRequest request, RequestEnc marker) throws IOException, ServletException, InvalidRequestException {
        if (properties.isVerifyRequestRemote()) {
            // Presumably compares the caller's IP against a hash kept in the session (see IPHash)
            boolean ipMatches = IPHash.checkHash(request.getSession(), ServletUtil.getClientIP(request));
            Assert.isTrue(ipMatches, RequsetTemplate.IP_ADDRESS_MISMATCH, InvalidRequestException::new);
        }
        SecureRequestWrapper wrapper = new SecureRequestWrapper(request, properties.isVerifyFileSignature());
        wrapper.decryptRequset(marker.sign());
        long timestamp = wrapper.getTimestamp();
        boolean expired = validator.isExpired(timestamp, properties.getLinkValidTime());
        Assert.isFalse(expired, () -> new ExpiredRequestException(RequsetTemplate.EXPIRED_REQUEST, timestamp, SystemClock.now()));
        // NOTE(review): the replay check runs before the signature is verified. If isRepeated
        // records the signature on this call, requests with forged signatures also occupy
        // replay-cache entries — confirm against RepeatValidator.
        boolean replayed = validator.isRepeated(timestamp, wrapper.sign());
        Assert.isFalse(replayed, () -> new RepeatedRequestException(RequsetTemplate.REPEATED_REQUEST, timestamp, SystemClock.now()));
        boolean signed = wrapper.compareSign();
        Assert.isTrue(signed, () -> new InvalidRequestException(RequsetTemplate.INVALID_REQUEST_SIGNATURE, wrapper.sign()));
        return wrapper;
    }

    /**
     * Publishes an {@link InitiatorEvent} describing a request that failed verification.
     * Note that {@link HttpServletRequest#getSession()} creates a session if none exists yet.
     *
     * @param request rejected request; its User-Agent header, client IP and session are recorded
     * @param handler handler the request was mapped to; its method and bean type are recorded
     * @param reason  verification failure message
     */
    private void pushEvent(HttpServletRequest request, HandlerMethod handler, String reason) {
        InitiatorEvent event = new InitiatorEvent(publisher);
        String agentHeader = request.getHeader(Header.USER_AGENT.getValue());
        UserAgent agent = UserAgentParser.parse(agentHeader);
        event.setUserAgent(agent);
        event.setRemote(ServletUtil.getClientIP(request));
        event.setSession(request.getSession());
        event.setMethod(handler.getMethod());
        event.setController(handler.getBeanType());
        event.setReason(reason);
        publisher.publishEvent(event);
    }

    /** Injected by Spring: publisher used for {@link InitiatorEvent}s. */
    @Autowired
    private void setPublisher(ApplicationEventPublisher eventPublisher) {
        this.publisher = eventPublisher;
    }

    /** Injected by Spring: switches and limits that drive verification. */
    @Autowired
    private void setProperties(SecureProperties secureProperties) {
        this.properties = secureProperties;
    }

    /** Injected by Spring: expiry and replay checks. */
    @Autowired
    private void setValidator(RepeatValidator repeatValidator) {
        this.validator = repeatValidator;
    }
}
